#
# Copyright 2020 The XLS Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=missing-function-docstring
"""Trains a model that predicts delay from a sequence of operations.

Training data is generated by xls/experimental/ml_delay_model/expr_generator.py.
The model input is a flattened vector of one-hot encoded opcodes in reverse
Polish notation, and the model outputs the predicted delay in ns.
"""

import argparse
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

# --- Input encoding -----------------------------------------------------------
# Width of each one-hot opcode slot (opcodes are matched against 1..NUM_OPCODES).
NUM_OPCODES = 8
# Number of opcode slots per sample; the model input has
# NUM_OPCODES * MAX_OP_COUNT features.
MAX_OP_COUNT = 8

# --- Network shape ------------------------------------------------------------
HIDDEN_LAYER1 = 250  # width of the first hidden layer
HIDDEN_LAYER2 = 250  # width of the second hidden layer (if enabled)
TWO_HIDDEN_LAYERS = True  # when False, the second hidden layer is omitted

# --- Training / evaluation ----------------------------------------------------
NUM_EPOCHS = 1
# Accuracy is reported for absolute-error thresholds EPSILON * k, k = 1..NUM_EPS.
EPSILON = 0.2
NUM_EPS = 5


class Net(nn.Module):
  """Fully connected regressor from one-hot opcode vectors to a delay value.

  The input has NUM_OPCODES * MAX_OP_COUNT features and the output has a
  single feature. There are one or two ReLU hidden layers, selected by
  TWO_HIDDEN_LAYERS.
  """

  def __init__(self):
    super().__init__()
    in_features = NUM_OPCODES * MAX_OP_COUNT
    # Layers are created in the order fc1, fc2, out so that parameter
    # registration (and random initialization) order is stable.
    self.fc1 = nn.Linear(in_features, HIDDEN_LAYER1)
    last_width = HIDDEN_LAYER1
    if TWO_HIDDEN_LAYERS:
      self.fc2 = nn.Linear(HIDDEN_LAYER1, HIDDEN_LAYER2)
      last_width = HIDDEN_LAYER2
    self.out = nn.Linear(last_width, 1)

  def forward(self, x):
    hidden = F.relu(self.fc1(x))
    if TWO_HIDDEN_LAYERS:
      hidden = F.relu(self.fc2(hidden))
    return self.out(hidden)


# Global training state shared by train() and test() below.
model = Net()
model.train()
# MSE is the optimized loss (and the reported validation loss); L1 is used by
# test() for the accuracy thresholds and the miss log.
criterion = nn.MSELoss()
l1 = nn.L1Loss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)


def train(data):
  """Runs one epoch of per-sample training of the global `model`.

  Samples are visited in a random order, and the optimizer takes one step per
  sample.

  Args:
    data: Indexable, sized collection of rows (e.g. a `random_split` subset of
      the CSV array). Each row holds the opcode sequence followed by the target
      delay as its last element. Must be non-empty.

  Returns:
    The mean MSE training loss over the epoch, as a Python float.
  """
  model.train()
  order = list(range(len(data)))
  random.shuffle(order)
  total_loss = 0.0
  for i in order:
    row = data[i]
    x, y = row[:-1], row[-1]
    # One-hot encode each opcode slot against opcodes 1..NUM_OPCODES; any other
    # value yields an all-zero slot.
    x_onehot = [
        [1 if opcode == j + 1 else 0 for j in range(NUM_OPCODES)]
        for opcode in x
    ]
    x_tensor = torch.flatten(torch.FloatTensor(x_onehot))
    y_tensor = torch.FloatTensor([y])

    # Forward pass, backward pass, parameter update.
    y_pred = model(x_tensor)
    loss = criterion(y_pred, y_tensor)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accumulate a plain float. Adding the loss tensor itself would chain every
    # sample's autograd graph into the running sum (unbounded memory growth).
    # It would also return a tensor that requires grad, which matplotlib cannot
    # plot.
    total_loss += loss.item()
  return total_loss / len(data)


def test(data, log_misses=False):
  """Evaluates the global `model` on `data` and prints error statistics.

  Prints the mean squared error. For each threshold EPSILON * k
  (k = 1..NUM_EPS), it also prints the fraction of samples whose absolute
  error is below that threshold.

  Args:
    data: Indexable, sized collection of rows. Each row holds the opcode
      sequence followed by the target delay as its last element. Must be
      non-empty.
    log_misses: If true, write one line per sample to
      ./data/data_<NUM_OPCODES>_<MAX_OP_COUNT>.log. Each line is
      (opcodes, target, prediction, absolute error), and lines are sorted by
      descending absolute error.

  Returns:
    The mean squared error over `data`, as a Python float.
  """
  model.eval()
  total = 0.0
  # within[j] counts samples with absolute error < EPSILON * (j + 1).
  within = [0] * NUM_EPS
  errors = []
  # Evaluation needs no autograd graph. Tracking one wastes memory and makes
  # the accumulated loss a grad-requiring tensor.
  with torch.no_grad():
    for i in range(len(data)):
      row = data[i]
      x, y = row[:-1], row[-1]
      x_onehot = [
          [1 if opcode == j + 1 else 0 for j in range(NUM_OPCODES)]
          for opcode in x
      ]
      x_tensor = torch.flatten(torch.FloatTensor(x_onehot))
      y_tensor = torch.FloatTensor([y])
      y_pred = model(x_tensor)

      total += criterion(y_pred, y_tensor).item()
      abs_error = l1(y_pred, y_tensor).item()
      for j in range(NUM_EPS):
        if abs_error < EPSILON * (j + 1):
          within[j] += 1
      if log_misses:
        errors.append((x, y, y_pred.item(), abs_error))

  if log_misses:
    # Only open (and truncate) the log file when logging was requested.
    errors.sort(key=lambda entry: entry[3], reverse=True)
    log_path = './data/data_{}_{}.log'.format(NUM_OPCODES, MAX_OP_COUNT)
    with open(log_path, 'w') as logfile:
      for entry in errors:
        print(entry, file=logfile)

  mse = total / len(data)
  accuracy = [count / len(data) for count in within]
  print('Validation MSE Loss: {}'.format(mse))
  print('Accuracy:', accuracy)
  return mse


def main(log_misses):
  """Trains the delay model and saves a plot of the per-epoch losses.

  Reads ./data/data_<NUM_OPCODES>_<MAX_OP_COUNT>.csv and splits it randomly
  80/20 into training and validation sets. It then runs NUM_EPOCHS epochs of
  train()/test() and saves the loss curves to
  ./data/loss_<NUM_OPCODES>_<MAX_OP_COUNT>.

  Args:
    log_misses: If true, per-sample validation errors are logged during the
      final epoch only.
  """
  # csv generated by xls/experimental/ml_delay_model/expr_generator.py
  data = np.genfromtxt(
      './data/data_{}_{}.csv'.format(NUM_OPCODES, MAX_OP_COUNT), delimiter=','
  )
  train_size = int(0.8 * len(data))
  train_data, test_data = torch.utils.data.random_split(
      data, [train_size, len(data) - train_size]
  )
  epochs = list(range(1, NUM_EPOCHS + 1))
  training = []
  validation = []
  for epoch in epochs:
    # 1-based, matching the x-axis of the saved plot.
    print('Epoch {}'.format(epoch))
    log_flag = bool(log_misses) and epoch == NUM_EPOCHS
    train_loss = train(train_data)
    valid_loss = test(test_data, log_flag)
    # float() accepts both Python numbers and 0-d tensors (even ones that
    # require grad). matplotlib would instead call .numpy() on tensors, which
    # fails for grad-requiring tensors.
    training.append(float(train_loss))
    validation.append(float(valid_loss))
  plt.plot(epochs, training, label='training')
  plt.plot(epochs, validation, label='validation')
  plt.xlabel('Epoch')
  plt.ylabel('MSE Loss')
  plt.legend(loc='upper right')
  plt.savefig('./data/loss_{}_{}'.format(NUM_OPCODES, MAX_OP_COUNT))
  plt.close()


def parse_bool_arg(value):
  """Converts a command-line string to a bool for argparse.

  argparse's `type=bool` would call bool() on the raw string. That is True for
  every non-empty string, including 'False' and '0'.

  Args:
    value: The raw argument string.

  Returns:
    True for 1/true/t/yes/y/on. False for 0/false/f/no/n/off or an empty
    string. Matching is case-insensitive and ignores surrounding whitespace.

  Raises:
    argparse.ArgumentTypeError: If `value` is not a recognized boolean.
  """
  normalized = value.strip().lower()
  if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
    return True
  if normalized in ('', '0', 'false', 'f', 'no', 'n', 'off'):
    return False
  raise argparse.ArgumentTypeError(
      'expected a boolean such as true/false, got {!r}'.format(value)
  )


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'log_misses',
      type=parse_bool_arg,
      help='whether to log per-sample validation errors in the final epoch',
  )
  args = parser.parse_args()
  main(args.log_misses)
